MaxPoolGrad
描述 MaxPool 的反向传播(梯度)计算。该算子将上游梯度(dy)只回传到前向最大池化过程中被选为最大值的位置;其它位置的梯度为 0。
数学定义:
\[\text{output}_{b,\ h_i,\ w_i,\ c} = \sum_{h_o=0}^{\text{output\_h}-1}\ \sum_{w_o=0}^{\text{output\_w}-1} \text{dy}_{b,\ h_o,\ w_o,\ c} \cdot \mathbb{1}\!\left[(h_i,\ w_i) = \arg\max_{(h,w)\in\mathcal{W}(h_o,w_o)} \text{input}_{b,\ h,\ w,\ c}\right]\]

其中,\(\mathbb{1}[\cdot]\) 为指示函数(条件成立时取 1,否则取 0),\(\mathcal{W}(h_o, w_o)\) 表示输出位置 \((h_o, w_o)\) 对应的池化窗口区域。若某输入位置不是任何窗口的最大值位置,其梯度为 0;当窗口相互重叠(步长小于窗口尺寸)时,同一输入位置可能在多个窗口中被选为最大值,此时其梯度为这些窗口对应上游梯度之和,与下文“累加”的实现一致。窗口像素位置 \((h, w)\) 可表示为:
\[h = h_o \cdot \text{stride}_h - \text{pad}_u + \Delta h\]\[w = w_o \cdot \text{stride}_w - \text{pad}_l + \Delta w\]\[\Delta h \in [0,\ \text{win}_h - 1], \qquad \Delta w \in [0,\ \text{win}_w - 1]\]并且仅当采样点落在输入有效范围内时会被考虑:
\[0 \le h < \text{in}_h, \qquad 0 \le w < \text{in}_w.\]
- 实现细节说明:
前向池化使用窗口 \(\text{win}_h \times \text{win}_w\),步长为 \(\text{stride}_h\), \(\text{stride}_w\),并且在边界处使用 pad(pad_u, pad_l)。
反向传播时,输出梯度 tensor(即需要写入的输入梯度)在每个 batch 开始前先被初始化为 0(代码中有一次整体清零)。
对于每个输出像素 \((h_o,w_o)\) 以及每个通道 c:
在对应的输入窗口中找到前向最大值的位置 \((h^*,w^*)\);
将上游梯度 \(\text{dy}_{b,h_o,w_o,c}\) 累加到该位置:\(\text{output}_{b,h^*,w^*,c} \mathrel{+}= \text{dy}_{b,h_o,w_o,c}\)。
其他位置梯度保持 0。
- 输入:
input - 输入张量指针,采用 NHWC 格式,形状为 \([batch,\ in\_h,\ in\_w,\ channel]\)
dy - 上游梯度张量指针,采用 NHWC 格式,形状为 \([batch,\ output\_h,\ output\_w,\ channel]\)
in_w - 输入张量的宽度 (W)
in_h - 输入张量的高度 (H)
win_w - 池化窗口的宽度,即窗口在 W 方向的大小
win_h - 池化窗口的高度,即窗口在 H 方向的大小
output_w - 输出特征图的宽度
output_h - 输出特征图的高度
batch - 批次大小,即输入中的 batch 数
channel - 通道数 C ,每个池化位置都分别对 C 个通道独立执行最大池化与裁剪
stride_w - 池化窗口在 W 方向的步长
stride_h - 池化窗口在 H 方向的步长
pad_l - 输入特征图左侧的填充大小
pad_u - 输入特征图上侧的填充大小
minf - 输出结果的下界值。池化结果会执行 \(\max(v,\ \text{minf})\)
maxf - 输出结果的上界值。池化结果会执行 \(\min(v,\ \text{maxf})\)
core_mask - 核心掩码,指定使用的计算核心
- 输出:
output - 输出张量指针,采用 NHWC 格式,形状为 \([batch,\ in\_h,\ in\_w,\ channel]\)。
- 支持平台:
| 平台 | 备注 |
| --- | --- |
| FT78NE | 支持 fp32、fp64 |
| MT7004 | 支持 fp16、fp32 |
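调用时将除 core_mask 外的参数打包,通过 long long params 数组传入,顺序为: input, dy, output, in_w, in_h, win_w, win_h, output_w, output_h, batch, channel, stride_w, stride_h, pad_l, pad_u, minf, maxf。其中指针参数和整型参数直接强制转换为 long long 后存入;minf 与 maxf 则需存入其地址(例如 (long long)&minf),不能把浮点数值本身强制转换为 long long。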
共享存储版本:
- `void hp_maxpool_grad_s(long long *params, int core_mask)`
- `void fp_maxpool_grad_s(long long *params, int core_mask)`
- `void dp_maxpool_grad_s(long long *params, int core_mask)`
C调用示例:
// FT78NE example: shared-storage MaxPoolGrad on fp64 (double) data.
#include <stdio.h>

// Prototype of the fp64 shared-storage entry point. Normally this comes from the
// vendor library header; declaring it avoids an implicit function declaration
// (invalid since C99). TODO: include the library header instead.
void dp_maxpool_grad_s(long long *params, int core_mask);

// NOTE(review): gin_w, gin_h and gbatch are not defined in this example; they are
// presumably globals provided by the test harness / loader -- TODO confirm.
extern int gin_w;
extern int gin_h;
extern int gbatch;

int main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;

    // Fixed device addresses of the NHWC buffers (platform-specific).
    double *input_ptr = (double *)0xA0000000;  // forward input: [batch, in_h, in_w, channel]
    double *dy_ptr = (double *)0xB0000000;     // upstream grad: [batch, output_h, output_w, channel]
    double *output_ptr = (double *)0xC0000000; // result:        [batch, in_h, in_w, channel]

    int in_w = gin_w;
    int in_h = gin_h;
    int win_w = 6;
    int win_h = 6;
    int batch = gbatch;
    int channel = 2;
    int stride_w = 4;
    int stride_h = 4;
    int pad_l = 1;
    int pad_u = 1;
    double minf = 0.0;  // lower clamp bound
    double maxf = 50.0; // upper clamp bound

    // Output size using ceiling division: ceil((in + 2*pad - win) / stride) + 1.
    // Assumes symmetric padding and in + 2*pad >= win; it must match the
    // configuration of the forward MaxPool that produced dy.
    int span_w = in_w + pad_l * 2 - win_w;
    int output_w = (span_w + stride_w - 1) / stride_w + 1;
    int span_h = in_h + pad_u * 2 - win_h;
    int output_h = (span_h + stride_h - 1) / stride_h + 1;

    // The order of the 17 packed arguments is fixed by the API.
    long long params[17];
    params[0] = (long long)input_ptr;
    params[1] = (long long)dy_ptr;
    params[2] = (long long)output_ptr;
    params[3] = (long long)in_w;
    params[4] = (long long)in_h;
    params[5] = (long long)win_w;
    params[6] = (long long)win_h;
    params[7] = (long long)output_w;
    params[8] = (long long)output_h;
    params[9] = (long long)batch;
    params[10] = (long long)channel;
    params[11] = (long long)stride_w;
    params[12] = (long long)stride_h;
    params[13] = (long long)pad_l;
    params[14] = (long long)pad_u;
    // minf/maxf are passed by ADDRESS: do not cast the double values themselves to long long.
    params[15] = (long long)&minf;
    params[16] = (long long)&maxf;

    int core_mask = 0x0f; // bit mask of compute cores to use (0x0f: presumably cores 0-3)

    // The buffers hold doubles, so call the dp_ variant; by the hp_/fp_/dp_ naming
    // this is presumably the fp64 entry point (fp_ would be fp32) -- confirm.
    dp_maxpool_grad_s(params, core_mask);
    return 0;
}
私有存储版本:
- `void hp_maxpool_grad_p(long long *params)`
- `void fp_maxpool_grad_p(long long *params)`
- `void dp_maxpool_grad_p(long long *params)`
C调用示例:
// FT78NE example: private-storage MaxPoolGrad on fp64 (double) data.
#include <stdio.h>

// Prototype of the fp64 private-storage entry point. Normally this comes from the
// vendor library header; declaring it avoids an implicit function declaration
// (invalid since C99). TODO: include the library header instead.
void dp_maxpool_grad_p(long long *params);

// NOTE(review): gin_w, gin_h and gbatch are not defined in this example; they are
// presumably globals provided by the test harness / loader -- TODO confirm.
extern int gin_w;
extern int gin_h;
extern int gbatch;

int main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;

    // Fixed device addresses of the NHWC buffers (platform-specific).
    double *input_ptr = (double *)0xA0000000;  // forward input: [batch, in_h, in_w, channel]
    double *dy_ptr = (double *)0xB0000000;     // upstream grad: [batch, output_h, output_w, channel]
    double *output_ptr = (double *)0xC0000000; // result:        [batch, in_h, in_w, channel]

    int in_w = gin_w;
    int in_h = gin_h;
    int win_w = 6;
    int win_h = 6;
    int batch = gbatch;
    int channel = 2;
    int stride_w = 4;
    int stride_h = 4;
    int pad_l = 1;
    int pad_u = 1;
    double minf = 0.0;  // lower clamp bound
    double maxf = 50.0; // upper clamp bound

    // Output size using ceiling division: ceil((in + 2*pad - win) / stride) + 1.
    // Assumes symmetric padding and in + 2*pad >= win; it must match the
    // configuration of the forward MaxPool that produced dy.
    int span_w = in_w + pad_l * 2 - win_w;
    int output_w = (span_w + stride_w - 1) / stride_w + 1;
    int span_h = in_h + pad_u * 2 - win_h;
    int output_h = (span_h + stride_h - 1) / stride_h + 1;

    // The order of the 17 packed arguments is fixed by the API.
    long long params[17];
    params[0] = (long long)input_ptr;
    params[1] = (long long)dy_ptr;
    params[2] = (long long)output_ptr;
    params[3] = (long long)in_w;
    params[4] = (long long)in_h;
    params[5] = (long long)win_w;
    params[6] = (long long)win_h;
    params[7] = (long long)output_w;
    params[8] = (long long)output_h;
    params[9] = (long long)batch;
    params[10] = (long long)channel;
    params[11] = (long long)stride_w;
    params[12] = (long long)stride_h;
    params[13] = (long long)pad_l;
    params[14] = (long long)pad_u;
    // minf/maxf are passed by ADDRESS: do not cast the double values themselves to long long.
    params[15] = (long long)&minf;
    params[16] = (long long)&maxf;

    // The buffers hold doubles, so call the dp_ variant; by the hp_/fp_/dp_ naming
    // this is presumably the fp64 entry point (fp_ would be fp32) -- confirm.
    dp_maxpool_grad_p(params);
    return 0;
}